package com.itcast.rpc.client.runner;

import com.itcast.common.data.RpcRequest;
import com.itcast.common.data.RpcResponse;
import com.itcast.common.utils.RpcException;
import com.itcast.common.utils.SpringBeanFactory;
import com.itcast.common.utils.StatusEnum;
import com.itcast.rpc.client.cache.ServiceRouteCache;
import com.itcast.rpc.client.channel.ChannelHolder;
import com.itcast.rpc.client.channel.ProviderService;
import com.itcast.rpc.client.cluster.ClusterStrategy;
import com.itcast.rpc.client.config.RpcClientConfiguration;
import com.itcast.rpc.client.connector.RpcClientConnector;
import com.itcast.rpc.client.connector.RpcClientInitializer;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.concurrent.*;

/**
 * RPC request manager.
 *
 * <p>Holds the static client-side routing and connection state and offers two ways of dispatching an
 * {@link RpcRequest} to a service provider: {@link #sendRequest} (blocking, returns the response)
 * and {@link #sendRequestAsync} (hands the connection work to an executor and writes the request
 * once a channel has been registered for it).
 */
public class RpcRequestManager {

    private static final Logger LOGGER = LoggerFactory.getLogger(RpcRequestManager.class);

    /** How long {@link #requestByNetty} blocks waiting for the provider's response, in seconds. */
    private static final int SYNC_RESPONSE_TIMEOUT_SECONDS = 3;

    private static ServiceRouteCache SERVICE_ROUTE_CACHE;

    private static RpcClientInitializer rpcClientInitializer;

    /** Provider selection strategy; stays null when the single-argument start method is used. */
    private static ClusterStrategy clusterStrategy;

    /** Only set by the three-argument start method. */
    private static RequestPool REQUEST_POOL;

    /** Channels bound to in-flight asynchronous requests, keyed by request id. */
    private static final ConcurrentHashMap<String, ChannelHolder> channelHolderHashMap = new ConcurrentHashMap<>();

    private static final ExecutorService REQUEST_EXECUTOR = new ThreadPoolExecutor(
            30,
            100,
            0,
            TimeUnit.SECONDS,
            new ArrayBlockingQueue<>(30),
            new BasicThreadFactory.Builder().namingPattern("request-service-connector-%d").build()
    );

    /**
     * Initializes the manager with only a route cache.
     *
     * <p>No cluster strategy or request pool is configured. Provider selection then falls back to the
     * first route, and {@link #registerChannelHolder} cannot be used.
     *
     * @param serviceRouteCache cache used to look up provider routes by class name
     */
    public static void startRpcRequestManager(ServiceRouteCache serviceRouteCache) {
        SERVICE_ROUTE_CACHE = serviceRouteCache;
        rpcClientInitializer = SpringBeanFactory.getBean(RpcClientInitializer.class);
    }

    /**
     * Initializes the manager with a route cache, the configured cluster strategy and a request pool.
     *
     * @param serviceRouteCache      cache used to look up provider routes by class name
     * @param rpcClientConfiguration configuration naming the cluster strategy bean type
     * @param requestPool            pool that in-flight asynchronous requests are submitted to
     */
    public static void startRpcRequestManager(ServiceRouteCache serviceRouteCache, RpcClientConfiguration rpcClientConfiguration,
                                                RequestPool requestPool) {
        SERVICE_ROUTE_CACHE = serviceRouteCache;
        rpcClientInitializer = SpringBeanFactory.getBean(RpcClientInitializer.class);
        // The strategy object lives in the Spring IoC container; look it up through SpringBeanFactory.
        clusterStrategy = SpringBeanFactory.getBean(rpcClientConfiguration.getRpcClientClusterStrategy());
        REQUEST_POOL = requestPool;
    }

    /**
     * Binds a connection to a request id and submits the request to the request pool on the
     * channel's event loop.
     *
     * <p>Requires the three-argument start method, because only it sets the request pool.
     *
     * @param requestId     id of the request the channel serves
     * @param channelHolder the established connection
     */
    public static void registerChannelHolder(String requestId, ChannelHolder channelHolder) {
        channelHolderHashMap.put(requestId, channelHolder);
        REQUEST_POOL.submitRequest(requestId, channelHolder.getChannel().eventLoop());
    }

    /**
     * Unbinds the connection registered for {@code requestId}, closing its channel and shutting
     * down its event loop group.
     *
     * <p>Unknown or already-released ids are ignored.
     *
     * @param requestId id of the request whose connection should be released
     */
    public static void destoryChannelHolder(String requestId) {
        ChannelHolder channelHolder = channelHolderHashMap.remove(requestId);
        if (channelHolder == null) {
            return;
        }
        // closeFuture() merely returns a future; close() actually releases the connection.
        channelHolder.getChannel().close();
        channelHolder.getEventLoopGroup().shutdownGracefully();
    }

    /**
     * Sends a client request asynchronously.
     *
     * <p>A connector task is executed on the request executor. After it counts the latch down, the
     * request is written to the channel registered for the request id.
     *
     * @param rpcRequest the request to send
     * @throws InterruptedException if interrupted while waiting for the connection
     * @throws RpcException         if no service provider is available
     * @throws IllegalStateException if no connection was registered for the request
     */
    public static void sendRequestAsync(RpcRequest rpcRequest) throws InterruptedException, RpcException {
        List<ProviderService> providerServices = SERVICE_ROUTE_CACHE.getServiceRoutes(rpcRequest.getClassName());
        ProviderService targetServiceProvider = selectProvider(providerServices);
        String requestId = rpcRequest.getRequestId();
        CountDownLatch latch = new CountDownLatch(1);
        REQUEST_EXECUTOR.execute(new RpcClientConnector(requestId, targetServiceProvider, latch));
        // NOTE(review): waits without a timeout; relies on RpcClientConnector always counting the
        // latch down, even when connecting fails — confirm.
        latch.await();
        ChannelHolder channelHolder = channelHolderHashMap.get(requestId);
        if (channelHolder == null) {
            throw new IllegalStateException("No connection was registered for request " + requestId);
        }
        channelHolder.getChannel().writeAndFlush(rpcRequest);
    }

    /**
     * Sends a client request and blocks until the response arrives or the wait times out.
     *
     * @param rpcRequest the request to send
     * @return the provider's response, or {@code null} if the call failed or timed out
     * @throws InterruptedException declared for callers; kept for interface compatibility
     * @throws RpcException         if no service provider is available
     */
    public static RpcResponse sendRequest(RpcRequest rpcRequest) throws InterruptedException, RpcException {
        // 1. Look up the provider routes for the target interface and pick one.
        List<ProviderService> providerServices = SERVICE_ROUTE_CACHE.getServiceRoutes(rpcRequest.getClassName());
        ProviderService targetServiceProvider = selectProvider(providerServices);

        String requestId = rpcRequest.getRequestId();
        // 2. Perform the remote call.
        RpcResponse response = requestByNetty(rpcRequest, targetServiceProvider);
        if (response != null) {
            LOGGER.info("Send request[{}:{}] to service provider successfully", requestId, rpcRequest.toString());
        } else {
            LOGGER.warn("Request[{}] to service provider[{}:{}] failed or timed out", requestId,
                    targetServiceProvider.getServerIp(), targetServiceProvider.getNetworkPort());
        }
        return response;
    }

    /**
     * Picks the provider to call.
     *
     * <p>Uses the configured cluster strategy when one is set, otherwise the first route.
     *
     * @param providerServices candidate providers; may be null or empty
     * @return the selected provider, never null
     * @throws RpcException if no provider is available or the strategy selects none
     */
    private static ProviderService selectProvider(List<ProviderService> providerServices) throws RpcException {
        if (providerServices == null || providerServices.isEmpty()) {
            throw new RpcException(StatusEnum.NOT_FOUND_SERVICE_PROVINDER);
        }
        ProviderService target = clusterStrategy != null
                ? clusterStrategy.select(providerServices)
                : providerServices.get(0);
        if (target == null) {
            throw new RpcException(StatusEnum.NOT_FOUND_SERVICE_PROVINDER);
        }
        return target;
    }

    /**
     * Performs the remote call over a dedicated Netty connection.
     *
     * <p>The connection and its event loop group are created for this call only. Both are always
     * released before returning, as is the pending-response entry.
     *
     * @param rpcRequest      the request to send
     * @param providerService the provider to connect to
     * @return the response, or {@code null} if connecting, writing or waiting failed or timed out
     */
    public static RpcResponse requestByNetty(RpcRequest rpcRequest, ProviderService providerService) {
        // 1. Build the Netty client configuration.
        EventLoopGroup worker = new NioEventLoopGroup();
        Bootstrap bootstrap = new Bootstrap();
        bootstrap.group(worker)
                .channel(NioSocketChannel.class)
                .remoteAddress(providerService.getServerIp(), providerService.getNetworkPort())
                .handler(rpcClientInitializer);
        String requestId = rpcRequest.getRequestId();
        ChannelHolder channelHolder = null;
        try {
            // 2. Establish the connection (sync() rethrows the cause if connecting fails).
            ChannelFuture future = bootstrap.connect().sync();
            if (!future.isSuccess()) {
                return null;
            }
            channelHolder = ChannelHolder.builder()
                    .channel(future.channel())
                    .eventLoopGroup(worker)
                    .build();
            LOGGER.info("Construct a connector with service provider[{}:{}] successfully",
                    providerService.getServerIp(),
                    providerService.getNetworkPort()
            );

            // 3. Create the response callback and 4. register it so the response can be matched.
            final RequestFuture<RpcResponse> responseFuture = new SyncRequestFuture(requestId);
            SyncRequestFuture.syncRequest.put(requestId, responseFuture);
            // 5. Send the request over the channel.
            ChannelFuture channelFuture = channelHolder.getChannel().writeAndFlush(rpcRequest);
            // 6. Record whether the write succeeded; drop the pending entry if it did not.
            channelFuture.addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    responseFuture.setWriteResult(future.isSuccess());
                    if (!future.isSuccess()) {
                        SyncRequestFuture.syncRequest.remove(responseFuture.requestId());
                    }
                }
            });
            // 7. Block until the response arrives or the timeout elapses.
            return responseFuture.get(SYNC_RESPONSE_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        } catch (Exception ex) {
            if (ex instanceof InterruptedException) {
                // Preserve the interrupt for the caller.
                Thread.currentThread().interrupt();
            }
            LOGGER.error("Request[{}] to service provider[{}:{}] failed", requestId,
                    providerService.getServerIp(), providerService.getNetworkPort(), ex);
            return null;
        } finally {
            // Always clear the pending entry and release the per-call connection and threads.
            SyncRequestFuture.syncRequest.remove(requestId);
            if (channelHolder != null) {
                channelHolder.getChannel().close();
            }
            worker.shutdownGracefully();
        }
    }

}
